Add initial inference data filtering function#621
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Adds UV packaging config to pyproject.toml. Updates preprocess.filter_equity_bars to accept configurable thresholds and adds tests. Enhances libraries/python/src/internal/tft_dataset.py to clone and validate inputs with a new dataset_schema, switch timestamps to UTC, generate calendar features, and cast time_idx and new features to Int64.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant C as Caller
    participant P as preprocess.filter_equity_bars
    C->>P: data (pl.DataFrame)
    activate P
    P->>P: clone input dataframe
    P->>P: group_by("ticker") → agg(avg_close_price, avg_volume)
    P->>P: filter(avg_close_price > minimum_average_close_price and avg_volume > minimum_average_volume)
    P-->>C: filtered per-ticker DataFrame
    deactivate P
```
```mermaid
sequenceDiagram
    autonumber
    participant U as Upstream Loader
    participant T as TFTDataset.__init__
    participant S as dataset_schema
    participant FE as FeatureEngineering
    participant SC as Scaler
    U->>T: raw data
    activate T
    T->>T: clone data (avoid mutation)
    T->>S: dataset_schema.validate(data)
    S-->>T: validated data
    rect rgba(220,235,255,0.5)
        note right of FE: timezone → UTC\ncompute day_of_week, day_of_month,\nday_of_year, month, year\ncast calendar features & time_idx → Int64
        T->>FE: apply timezone & compute features
        FE-->>T: augmented/cast dataset
    end
    T->>SC: scaling/processing on validated data
    SC-->>T: scaled dataset
    T-->>U: initialized TFTDataset
    deactivate T
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45–70 minutes
Pull Request Overview
This PR adds initial data filtering functionality and refactors the TFT dataset processing. It introduces a new filter_equity_bars function to filter financial data based on average close price and volume thresholds, along with schema validation using pandera.
- Adds `filter_equity_bars` function with configurable price and volume thresholds
- Introduces data schema validation using pandera for equity bar data
- Refactors TFT dataset timezone handling and type casting for consistency
Reviewed Changes
Copilot reviewed 4 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| libraries/python/src/internal/tft_dataset.py | Adds schema validation, changes timezone to UTC, improves type consistency with Int64 casting |
| applications/portfoliomanager/src/portfoliomanager/preprocess.py | Implements new filtering function for equity data based on price and volume thresholds |
| applications/portfoliomanager/tests/test_preprocess.py | Comprehensive test suite covering various filtering scenarios and edge cases |
| applications/portfoliomanager/pyproject.toml | Updates project configuration to support package structure |
Graphite Automations: "Assign author to pull request" took an action on this PR (08/22/25). 1 assignee was added to this PR based on John Forstmeier's automation.
… github.com:pocketsizefund/pocketsizefund into 08-22-add_initial_inference_data_filtering_function
Actionable comments posted: 1
♻️ Duplicate comments (1)
applications/portfoliomanager/src/portfoliomanager/preprocess.py (1)
7-9: Make thresholds configurable via parameters (repeat of earlier feedback)

Hardcoded thresholds make experimentation cumbersome. Thread them as optional parameters with sensible defaults to preserve current behavior.
```diff
-def filter_equity_bars(data: pl.DataFrame) -> pl.DataFrame:
+def filter_equity_bars(
+    data: pl.DataFrame,
+    minimum_average_close_price: float = 10.0,
+    minimum_average_volume: float = 1_000_000.0,
+    *,
+    strict: bool = True,
+) -> pl.DataFrame:
     data = data.clone()
-    minimum_average_close_price = 10.0
-    minimum_average_volume = 1_000_000.0
+    price_col = pl.col("avg_close_price")
+    vol_col = pl.col("avg_volume")
+    price_pred = price_col > minimum_average_close_price if strict else price_col >= minimum_average_close_price
+    vol_pred = vol_col > minimum_average_volume if strict else vol_col >= minimum_average_volume
     return (
         data.group_by("ticker")
         .agg(
             avg_close_price=pl.col("close_price").mean(),
             avg_volume=pl.col("volume").mean(),
         )
-        .filter(
-            (pl.col("avg_close_price") > minimum_average_close_price)
-            & (pl.col("avg_volume") > minimum_average_volume)
-        )
+        .filter(price_pred & vol_pred)
     )
```
🧹 Nitpick comments (2)
applications/portfoliomanager/pyproject.toml (1)
8-11: Packaging enabled: good step; ensure direct runtime deps are declared

Enabling `[tool.uv] package = true` means `portfoliomanager` can be built/installed independently. Since `portfoliomanager.preprocess` imports `polars`, declare `polars` as a direct dependency here instead of relying on transitive installs via `internal`, to prevent import errors for downstream consumers.

Apply:
```diff
 [project]
 name = "portfoliomanager"
 version = "0.1.0"
 description = "Portfolio prediction and construction service"
-requires-python = "==3.12.10"
-dependencies = ["internal"]
+requires-python = ">=3.12,<3.13"
+dependencies = [
+    "internal",
+    "polars",  # direct runtime dependency used by portfoliomanager.preprocess
+]
```

applications/portfoliomanager/tests/test_preprocess.py (1)
1-4: Minor nit: prefer DataFrame.height for row counts

`len(df)` works, but `df.height` reads clearer and avoids constructing an iterator in some backends. Optional.

Example:

```diff
-assert len(result) == 1
+assert result.height == 1
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (4)
- applications/portfoliomanager/pyproject.toml (1 hunks)
- applications/portfoliomanager/src/portfoliomanager/preprocess.py (1 hunks)
- applications/portfoliomanager/tests/test_preprocess.py (1 hunks)
- libraries/python/src/internal/tft_dataset.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
applications/portfoliomanager/src/portfoliomanager/preprocess.py (1)
applications/datamanager/src/datamanager/alpaca_client.py (1)
fetch_latest_data(64-170)
libraries/python/src/internal/tft_dataset.py (2)
libraries/python/src/internal/dataset.py (1)
- TemporalFusionTransformerDataset (28-321)

libraries/python/tests/test_equity_bar.py (9)
- test_equity_bar_schema_type_coercion (180-198)
- test_equity_bar_schema_valid_data (7-22)
- test_equity_bar_schema_zero_timestamp (79-94)
- test_equity_bar_schema_negative_timestamp (61-76)
- test_equity_bar_schema_multiple_rows (237-253)
- test_equity_bar_schema_ticker_uppercase_passes (43-58)
- test_equity_bar_schema_negative_volume (144-159)
- test_equity_bar_schema_ticker_lowercase_fails (25-40)
- test_equity_bar_schema_missing_required_column (201-216)
applications/portfoliomanager/tests/test_preprocess.py (1)
applications/portfoliomanager/src/portfoliomanager/preprocess.py (1)
filter_equity_bars(4-20)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Run Python quality checks
- GitHub Check: Run Python tests
🔇 Additional comments (9)
applications/portfoliomanager/src/portfoliomanager/preprocess.py (1)
4-20: Solid, side-effect-free filtering; aligns with tests

Cloning the input for immutability and using per-ticker aggregates with strict "greater-than" thresholds is clear and easy to reason about. The returned schema (ticker, avg_close_price, avg_volume) matches the tests.
libraries/python/src/internal/tft_dataset.py (5)
30-31: Good call to clone the input

Defensive copy avoids mutating upstream dataframes during feature engineering.
147-154: Explicit Int64 casts for calendar features improve dtype stability

Casting calendar columns to Int64 ensures deterministic dtypes across platforms and Polars versions.
165-166: Schema validation placed pre-scaling: good

Validating the full engineered frame's dtypes before scaling/mapping will surface data issues early and keeps the scaler's inputs consistent.
140-143: Timezone change to UTC requires downstream validation

The recent switch in `tft_dataset.py` from using America/New_York–derived midnight to UTC for timestamp normalization may shift epoch seconds by up to 5 hours (depending on DST), which can impact any downstream logic or consumers built around NY-local dates. Please review the following areas where America/New_York–based timestamps are explicitly used and confirm that their behavior remains correct:

- libraries/python/src/internal/cloud_event.py (lines 16–20, 33–37): event "date" fields generated with `ZoneInfo("America/New_York")`.
- libraries/python/src/internal/dates.py (lines 9–13, 21–25): default_factory and `parser.replace(tzinfo=ZoneInfo("America/New_York"))`.
- libraries/python/tests/test_dates.py (lines 11–15): tests mocking `datetime.now(tz=ZoneInfo("America/New_York"))`.
- applications/models/src/models/combine_tft_data_sources.py (lines 40–44): output filename timestamp uses NY timezone.
- applications/models/src/models/train_tft_model.py (lines 20–30, 44–53, 87–95, 100–108, 123–131, 141–142): logging runtimes and run names all use `ZoneInfo("America/New_York")`.
- applications/models/src/models/get_alpaca_equity_bars.py (lines 61–63): `end = datetime.now(tz=ZoneInfo("America/New_York"))`.
- applications/datamanager/src/datamanager/alpaca_client.py (lines 102–114): converts naive timestamps to UTC then compares dates in NY to determine bar grouping.
- applications/datamanager/tests/test_alpaca_client.py (lines 36–44, 72–76, 130–138): tests assume bar timestamps at NY 16:00 local.

Double-check each consumer's use of epoch seconds or date comparisons to ensure no unintended shifts occur now that `tft_dataset` uses UTC normalization.
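To make the magnitude of the shift concrete, here is a stdlib-only sketch (the dates are arbitrary illustrations, not values from the repository):

```python
from datetime import datetime
from zoneinfo import ZoneInfo

# Midnight on the same calendar date yields different epoch seconds
# depending on which timezone the normalization uses.
ny_midnight = datetime(2024, 1, 15, tzinfo=ZoneInfo("America/New_York"))
utc_midnight = datetime(2024, 1, 15, tzinfo=ZoneInfo("UTC"))

shift_seconds = ny_midnight.timestamp() - utc_midnight.timestamp()
print(shift_seconds)  # 18000.0, i.e. 5 hours under EST (4 hours under EDT)
```

Any consumer that truncates these epochs back to dates in a different zone can land on the previous or next calendar day.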
156-163: time_idx bump to Int64 is safe — no Int32 consumers found

I've verified that:

- The only definitions and uses of `time_idx` are within `tft_dataset.py`, and it's now consistently cast to `pl.Int64`.
- The Pandera schema defines `"time_idx": pa.Column(int, required=True)`, which maps to Python's `int` (i.e. NumPy/Pandas int64) rather than a 32-bit type.

No other code paths serialize or consume `time_idx` as a 32-bit integer. You can safely keep the Int64 change.

applications/portfoliomanager/tests/test_preprocess.py (3)
6-25: Tests hit the intended grouping/aggregation behavior

Covers strict "greater-than" semantics and validates both aggregate values and output schema. Nice.
73-114: Boundary and near-boundary coverage is thorough

Exact-threshold exclusion and "just above" inclusion guard against off-by-one mistakes.
197-214: Immutability test is valuable

Confirms the clone pattern is honored and protects callers from accidental mutation.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
libraries/python/src/internal/tft_dataset.py (2)
103-121: Bug: weekday/weekend holiday mapping is inverted for null rows

The logic sets `is_holiday=True` on weekdays (Mon–Fri) and `False` on weekends for rows where `is_holiday` is null. That's the opposite of typical market calendars. This will mislabel most business days as holidays and weekends as non-holidays, skewing categorical signals and training targets.

Proposed fix (invert the branches):

```diff
-            .then(True)  # noqa: FBT003
+            .then(False)  # weekdays are not holidays
             ...
-            .then(False)  # noqa: FBT003
+            .then(True)  # weekends are holidays
```

If you intend "holiday" to mean "non-trading day," this inversion is required. If you intend "holiday" to specifically mean exchange-observed holidays (subset of weekdays), consider driving this from an exchange calendar (e.g., pandas-market-calendars) rather than heuristics.
126-145: Avoid filling OHLC/VWAP with zeros; prefer forward-fill (per ticker) and zero volume only on non-trading days

Filling prices with 0.0 creates non-physical values and teaches the model that prices collapse to zero on missing days. Forward-filling prices for non-trading days is a better default; volume can be zero on those days.

Example refactor (preserves your null-timestamp fallback; forward-fills per ticker after sorting):

```diff
-        data = data.with_columns(
-            [
-                pl.col("open_price").fill_null(0.0),
-                pl.col("high_price").fill_null(0.0),
-                pl.col("low_price").fill_null(0.0),
-                pl.col("close_price").fill_null(0.0),
-                pl.col("volume").fill_null(0.0),
-                pl.col("volume_weighted_average_price").fill_null(0.0),
-                pl.col("sector").fill_null("Not Available"),
-                pl.col("industry").fill_null("Not Available"),
-                pl.col("ticker").fill_null("UNKNOWN"),
-                pl.col("timestamp").fill_null(
-                    pl.col("date")
-                    .cast(pl.Datetime)
-                    .dt.replace_time_zone("UTC")
-                    .cast(pl.Int64)
-                    .floordiv(1000)
-                ),
-            ]
-        )
+        data = data.sort(["ticker", "date"]).with_columns(
+            [
+                # forward-fill prices within each ticker
+                pl.col("open_price").fill_null(strategy="forward").over("ticker"),
+                pl.col("high_price").fill_null(strategy="forward").over("ticker"),
+                pl.col("low_price").fill_null(strategy="forward").over("ticker"),
+                pl.col("close_price").fill_null(strategy="forward").over("ticker"),
+                # zero volume on non-trading days (still nulls after join)
+                pl.col("volume").fill_null(0.0),
+                pl.col("volume_weighted_average_price").fill_null(
+                    strategy="forward"
+                ).over("ticker"),
+                pl.col("sector").fill_null("Not Available"),
+                pl.col("industry").fill_null("Not Available"),
+                pl.col("ticker").fill_null("UNKNOWN"),
+                pl.col("timestamp").fill_null(
+                    pl.col("date")
+                    .cast(pl.Datetime)
+                    .dt.replace_time_zone("UTC")
+                    .dt.epoch("ms")
+                ),
+            ]
+        )
```

This prevents artificial price shocks while keeping volume semantics reasonable.
🧹 Nitpick comments (4)
libraries/python/src/internal/tft_dataset.py (4)
140-143: Use dt.epoch('ms') for clarity and unit safety when deriving fallback timestamps

Casting to Int64 and dividing by 1,000 relies on the engine's default time unit. `dt.epoch('ms')` makes the intent explicit and robust.

```diff
-                    .cast(pl.Int64)
-                    .floordiv(1000)
+                    .dt.epoch("ms")
```
165-165: Good: validate with schema before scaling

Validating prior to scaling prevents dirty data from contaminating the scaler. Consider `strict=True` (see next comment) if you want to forbid extra columns creeping in.
330-382: Tighten schema: set strict=True and align with upstream equity schema to avoid drift

- Add `strict=True` to reject unexpected columns.
- Optional: add `ordered=False` explicitly for clarity (defaults are fine).
- Ensure this schema stays in sync with any upstream `equity_bar_schema` to avoid silent divergence in constraints.

Proposed minimal change:

```diff
 dataset_schema = pa.DataFrameSchema(
     {
         # columns...
-    }
-)
+    },
+    strict=True,
+)
```

If upstream uses different dtypes (e.g., `volume` as float), decide on one canonical schema and convert at the ingestion boundary.
86-101: Scope of backfilling: avoid creating rows outside each ticker's active date range

Using global min/max dates and a cross join will create rows for dates where a ticker didn't exist/list—later filled by forward-fill/zeros—introducing long tails of synthetic data. Prefer per-ticker date ranges:

- Compute per-ticker min/max dates.
- Join each ticker to its own date range.

If you'd like, I can draft a per-ticker `date_range` construction using Polars' group-wise operations.

Also applies to: 97-101, 123-123
📒 Files selected for processing (2)
- applications/portfoliomanager/src/portfoliomanager/preprocess.py (1 hunks)
- libraries/python/src/internal/tft_dataset.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- applications/portfoliomanager/src/portfoliomanager/preprocess.py
🧰 Additional context used
🧬 Code graph analysis (1)
libraries/python/src/internal/tft_dataset.py (2)
libraries/python/src/internal/dataset.py (3)
- TemporalFusionTransformerDataset (28-321)
- __init__ (29-188)
- Scaler (10-25)

libraries/python/tests/test_dataset.py (1)
- test_dataset_load_data (5-45)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Run Python tests
🔇 Additional comments (4)
libraries/python/src/internal/tft_dataset.py (4)
1-3: Imports look good; Polars + Pandera(polars) + date type are correctly wired

Nothing to change here; the `date` import is needed for the schema and `pandera.polars` is the right module.
30-31: Nice: cloning input avoids side effects

Cloning `data` up front prevents subtle caller mutations. Good call.
147-154: Confirm zero-based day_of_week representation

I've audited all occurrences of `dt.weekday()` and the `day_of_week` column in the repo—both reside in `tft_dataset.py`:

- libraries/python/src/internal/tft_dataset.py
  - Lines 103–107: computes a temporary weekday via `pl.col("date").dt.weekday().alias("temporary_weekday")`
  - Lines 147–152: assigns `day_of_week = pl.col("date").dt.weekday().alias("day_of_week")`

No offsets (`+1`) or alternative conventions were found elsewhere. If your downstream embeddings or persisted models expect Monday = 1…Sunday = 7, please adjust these to `dt.weekday() + 1`. Otherwise, consider adding a brief doc comment (in the function docstring or module README) to clarify that `day_of_week` is zero-based, ensuring consistency across training and inference.
156-163: time_idx indexing is acceptable as-is

After reviewing the downstream usage of `time_idx`—sorting by it in the batching logic (around line 295) and declaring it in the schema (around line 380)—there are no existing consumers expecting a 0-based index. If you prefer 0-based numbering for array-style consistency, you can adjust the rank:

```diff
-        .rank("dense")
+        .rank("dense") - 1
```

However, since no parts of the code assume 0-based indexing, the current 1-based implementation is fine.
Comments
I'll manually test this one and these minimum limits are likely gonna be updated. UPDATE: this has been manually tested and trims the training data from 5,829 tickers to 1,453.